"""
Phase 4: Supply-Demand Dynamics
================================
Constructs safe asset supply and demographic demand measures.
Tests supply-demand ratio → convenience yield, panel interactions,
and bridges to fiscal dominance findings.

Output: table4_supply_demand.md, table4b_fiscal_paradox.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Project root. This is a hard-coded absolute path; adjust it when running
# on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_assets")
# Sibling "multilateral" project whose src/ directory provides `model`.
MULTILATERAL_DIR = PROJECT_DIR.parent / "multilateral"
# Put multilateral/src first on the import path so the `from model import`
# below resolves to that project's module.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS

# Input panel location (read by main) and destination for the markdown tables.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# ISO3 codes of the 38 OECD member countries (not referenced by the code
# visible in this file).
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Control variables appended to every regression. main() keeps only those
# present in the panel, and the table footnotes list this full set.
CONTROLS = ['rgdp_growth', 'inflation', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']


def stars(p):
    """Return conventional significance stars for a p-value.

    '***' if p < 0.01, '**' if p < 0.05, '*' if p < 0.10, otherwise ''.
    A NaN p-value yields '' because every comparison with NaN is False.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, feature_names=None):
    """Fit PanelGLS of ``dep_var`` on ``regressors`` and return a results dict.

    Regressors that are not columns of ``df`` are dropped. Rows with missing
    values in the dependent variable or any remaining regressor are then
    dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel holding ``dep_var``, the regressors, and the ``iso3`` / ``year``
        identifier columns passed to ``PanelGLS.fit``.
    dep_var : str
        Name of the dependent-variable column.
    regressors : list of str
        Candidate regressor columns, in the order given to the estimator.
    label : str
        Model label used in console output and stored under ``'label'``.
    feature_names : list of str, optional
        Display names aligned one-to-one with ``regressors``. When a regressor
        is dropped for being absent from ``df``, its display name is dropped
        with it. If omitted or of a different length than ``regressors``, the
        column names themselves are used.

    Returns
    -------
    dict or None
        ``None`` when the model is skipped: the dependent variable is missing,
        no regressor is available, or fewer than 50 complete observations
        remain. Otherwise the dict has ``label``, ``dep_var``, ``n_obs``,
        ``n_countries``, ``r_squared`` and ``rho``, plus ``coef_<name>``,
        ``se_<name>`` and ``p_<name>`` for each regressor.
    """
    if dep_var not in df.columns:
        print(f"  [{label}] Dep var {dep_var} missing — skipping")
        return None

    # Filter regressors and their display names together so custom names
    # stay aligned with the coefficients when some columns are unavailable.
    if feature_names and len(feature_names) == len(regressors):
        pairs = [(reg, name) for reg, name in zip(regressors, feature_names)
                 if reg in df.columns]
    else:
        pairs = [(reg, reg) for reg in regressors if reg in df.columns]
    regressors = [reg for reg, _ in pairs]
    names = [name for _, name in pairs]

    if not regressors:
        print(f"  [{label}] No regressors available — skipping")
        return None

    sub = df.dropna(subset=[dep_var] + regressors)
    if len(sub) < 50:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)

    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, "
          f"R²={gls.r_squared:.4f}, rho={gls.rho:.3f}")

    results = {
        'label': label,
        'dep_var': dep_var,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    # Coefficients are read positionally, in the same order as `regressors`.
    for i, name in enumerate(names):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<30} {gls.beta[i]:>8.4f} ({gls.se[i]:.4f}) {sig}")

    return results


def main():
    """Run the Phase 4 supply-demand analysis end to end.

    Loads ``safe_asset_panel.csv`` from ``PROCESSED_DIR`` and derives the
    supply-demand and debt-interaction variables. It then fits models M1a–M5
    via ``run_model``, skipping each model whose inputs are unavailable.
    Finally it writes Tables 4 and 4b through the two table builders.
    """
    print("=" * 70)
    print("PHASE 4: Supply-Demand Dynamics")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "safe_asset_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    all_results = []
    # The demographic factors are indexed directly when building the
    # interaction terms, so keep only those present in the panel. This
    # mirrors the filter applied to the controls.
    demo_vars = [z for z in ['Z_1', 'Z_2', 'Z_3'] if z in df.columns]
    controls = [c for c in CONTROLS if c in df.columns]
    has_debt = 'govt_debt_gdp' in df.columns
    has_safe = 'safe_issuer' in df.columns

    # ── Construct additional supply-demand variables ──
    print("\n[1] Constructing supply-demand variables ...")

    # Excess demand: global_oadr × (1 - safe_supply_ratio)
    if 'global_oadr' in df.columns and 'safe_supply_ratio' in df.columns:
        df['safe_excess_demand'] = df['global_oadr'] * (1 - df['safe_supply_ratio'])
        print(f"  safe_excess_demand: mean={df['safe_excess_demand'].mean():.4f}")

    # Country-level safe supply / safe-debt interaction:
    # govt_debt_gdp × safe_issuer. It needs both columns.
    if has_debt and has_safe:
        df['debt_x_safe'] = df['govt_debt_gdp'] * df['safe_issuer']
    if has_debt:
        for z in demo_vars:
            df[f'{z}_x_debt'] = df[z] * df['govt_debt_gdp']
            if has_safe:
                df[f'{z}_x_safe_debt'] = df[z] * df['debt_x_safe']

    # ================================================================
    # SECTION A: Safe Supply → Rates (Time-Series Channel)
    # ================================================================
    print("\n" + "─" * 60)
    print("A. Safe Supply Ratio → Rates")
    print("─" * 60)

    if 'safe_supply_ratio' in df.columns:
        # M1a: safe_supply_ratio → real 10y
        r = run_model(df, 'real_bond_10y', ['safe_supply_ratio'] + controls,
                      "M1a: safe_supply → 10y", ['safe_supply_ratio'] + controls)
        if r: all_results.append(r)

        # M1b: safe_supply_ratio + Z → real 10y (horse race)
        r = run_model(df, 'real_bond_10y',
                      demo_vars + ['safe_supply_ratio'] + controls,
                      "M1b: Z + safe_supply → 10y",
                      demo_vars + ['safe_supply_ratio'] + controls)
        if r: all_results.append(r)

    # M1c: global_oadr → 10y (demand-side)
    if 'global_oadr' in df.columns:
        r = run_model(df, 'real_bond_10y', ['global_oadr'] + controls,
                      "M1c: global_oadr → 10y", ['global_oadr'] + controls)
        if r: all_results.append(r)

    # M1d: excess demand → 10y
    if 'safe_excess_demand' in df.columns:
        r = run_model(df, 'real_bond_10y', ['safe_excess_demand'] + controls,
                      "M1d: excess_demand → 10y", ['safe_excess_demand'] + controls)
        if r: all_results.append(r)

    # ================================================================
    # SECTION B: Convenience Yield Proxies
    # ================================================================
    print("\n" + "─" * 60)
    print("B. Convenience Yield Proxies")
    print("─" * 60)

    # M2a: safe_supply → lending-govt spread
    if 'lending_govt_spread' in df.columns and 'safe_supply_ratio' in df.columns:
        r = run_model(df, 'lending_govt_spread',
                      ['safe_supply_ratio'] + controls,
                      "M2a: safe_supply → conv yield",
                      ['safe_supply_ratio'] + controls)
        if r: all_results.append(r)

    # M2b: Z + safe_supply → lending-govt spread
    if 'lending_govt_spread' in df.columns:
        extra = ['safe_supply_ratio'] if 'safe_supply_ratio' in df.columns else []
        r = run_model(df, 'lending_govt_spread',
                      demo_vars + extra + controls,
                      "M2b: Z → conv yield",
                      demo_vars + extra + controls)
        if r: all_results.append(r)

    # ================================================================
    # SECTION C: Safe Issuer × Debt Interactions
    # ================================================================
    print("\n" + "─" * 60)
    print("C. Safe Issuer × Debt Interactions (He-Krishnamurthy-Milbradt)")
    print("─" * 60)

    # Does more safe debt = safer when demand is high?
    if has_debt:
        # M3a: debt × safe_issuer → sovereign spread. This requires
        # safe_issuer; without it run_model would drop the interaction terms
        # and silently fit a debt-only model under this label.
        if has_safe and 'sovereign_spread' in df.columns:
            regs = ['govt_debt_gdp', 'safe_issuer', 'debt_x_safe'] + controls
            r = run_model(df, 'sovereign_spread', regs,
                          "M3a: debt×safe → spread", regs)
            if r: all_results.append(r)

        # M3b: Z × debt → 10y bond
        z_debt_vars = [f'{z}_x_debt' for z in demo_vars if f'{z}_x_debt' in df.columns]
        if z_debt_vars:
            regs = demo_vars + ['govt_debt_gdp'] + z_debt_vars + controls
            r = run_model(df, 'real_bond_10y', regs, "M3b: Z×debt → 10y", regs)
            if r: all_results.append(r)

        # M3c: Z × safe_debt → 10y (triple interaction; empty when the
        # safe_issuer column is absent, since those terms were not built)
        z_safe_debt = [f'{z}_x_safe_debt' for z in demo_vars if f'{z}_x_safe_debt' in df.columns]
        if z_safe_debt:
            regs = demo_vars + ['govt_debt_gdp', 'safe_issuer'] + z_safe_debt + controls
            r = run_model(df, 'real_bond_10y', regs,
                          "M3c: Z×safe_debt → 10y", regs)
            if r: all_results.append(r)

    # ================================================================
    # SECTION D: Fiscal Paradox (Bridge to fiscal_dominance paper)
    # ================================================================
    print("\n" + "─" * 60)
    print("D. Fiscal Paradox: Expenditure-Revenue Asymmetry for Safe Issuers")
    print("─" * 60)

    if has_safe:
        # Subsets are only read (run_model copies after dropna).
        safe = df[df['safe_issuer'] == 1]
        nonsafe = df[df['safe_issuer'] == 0]

        # M4a: Z → govt expenditure (safe issuers)
        if 'govt_expenditure_gdp' in safe.columns:
            r = run_model(safe, 'govt_expenditure_gdp', demo_vars + controls,
                          "M4a: Safe Z → expenditure", demo_vars + controls)
            if r: all_results.append(r)

        # M4b: Z → govt revenue (safe issuers)
        if 'govt_revenue_gdp' in safe.columns:
            r = run_model(safe, 'govt_revenue_gdp', demo_vars + controls,
                          "M4b: Safe Z → revenue", demo_vars + controls)
            if r: all_results.append(r)

        # M4c: Z → govt_debt_gdp (safe issuers — demographic fiscal pressure)
        if 'govt_debt_gdp' in safe.columns:
            r = run_model(safe, 'govt_debt_gdp', demo_vars + controls,
                          "M4c: Safe Z → debt/GDP", demo_vars + controls)
            if r: all_results.append(r)

        # M4d: Same for non-safe (comparison)
        if 'govt_debt_gdp' in nonsafe.columns:
            r = run_model(nonsafe, 'govt_debt_gdp', demo_vars + controls,
                          "M4d: Non-safe Z → debt/GDP", demo_vars + controls)
            if r: all_results.append(r)
    else:
        print("  safe_issuer missing — skipping fiscal paradox models")

    # ================================================================
    # SECTION E: N_safe_issuers → Rates (Shrinking Supply)
    # ================================================================
    print("\n" + "─" * 60)
    print("E. Number of Safe Issuers Over Time")
    print("─" * 60)

    if 'n_safe_issuers' in df.columns:
        r = run_model(df, 'real_bond_10y',
                      ['n_safe_issuers'] + demo_vars + controls,
                      "M5: n_safe + Z → 10y",
                      ['n_safe_issuers'] + demo_vars + controls)
        if r: all_results.append(r)

    # ── Build results tables ──
    print("\n\nBuilding results tables ...")
    build_supply_demand_table(all_results)
    build_fiscal_paradox_table(all_results)

    print("\n" + "=" * 70)
    print("Phase 4 complete.")
    print("=" * 70)


def build_supply_demand_table(all_results):
    """Write Table 4 (supply-demand models) to ``TABLES_DIR`` as markdown.

    Models whose label contains ``'M4'`` are excluded here; they are reported
    in Table 4b instead. The table contains a per-model summary (N, countries,
    R², ρ) followed by the coefficients of ``key_vars`` that each model
    actually estimated.

    Parameters
    ----------
    all_results : list of dict
        Result dicts as returned by ``run_model``. If the list is empty,
        nothing is written.
    """
    if not all_results:
        print("  No results to tabulate.")
        return

    # Coefficients reported in the table; controls are intentionally omitted.
    key_vars = ['Z_1', 'Z_2', 'Z_3', 'safe_supply_ratio', 'global_oadr',
                'safe_excess_demand', 'n_safe_issuers',
                'govt_debt_gdp', 'safe_issuer', 'debt_x_safe',
                'Z_1_x_debt', 'Z_1_x_safe_debt']

    # Fiscal-paradox models (M4*) belong to Table 4b.
    sd_results = [r for r in all_results if 'M4' not in r['label']]

    md = ["# Table 4: Supply-Demand Dynamics\n"]

    md.append("## Model Summary\n")
    md.append("| Model | Dep Var | N | Countries | R² | ρ |")
    md.append("|---|---|---|---|---|---|")
    for r in sd_results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} | {r['rho']:.3f} |")

    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in sd_results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")

    md.append(f"\n*Controls: {', '.join(CONTROLS)}*")
    md.append("*PanelGLS with AR(1) correction.*")

    # The output directory may not exist yet. The table also contains
    # non-ASCII symbols (², ρ, →, ×), so write UTF-8 explicitly rather than
    # relying on the platform's default encoding.
    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    out = TABLES_DIR / "table4_supply_demand.md"
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def build_fiscal_paradox_table(all_results):
    """Write Table 4b (fiscal paradox for safe issuers) to ``TABLES_DIR``.

    Only results whose label contains ``'M4'`` are used. The table has a model
    summary (N, countries, R²) and the Z₁ / Z₂ coefficients with p-values.
    A coefficient that a model did not estimate is shown as NaN.

    Parameters
    ----------
    all_results : list of dict
        Result dicts as returned by ``run_model``. If none is an M4 model,
        nothing is written.
    """
    fiscal = [r for r in all_results if 'M4' in r['label']]
    if not fiscal:
        print("  No fiscal results to tabulate.")
        return

    md = ["# Table 4b: Fiscal Paradox — Safe Issuers\n"]
    md.append("Aging safe issuers face rising fiscal pressure that may threaten their safety status.\n")

    md.append("| Model | Dep Var | N | Countries | R² |")
    md.append("|---|---|---|---|---|")
    for r in fiscal:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} |")

    md.append("\n## Demographic Coefficients\n")
    md.append("| Model | Z₁ | Z₁ p | Z₂ | Z₂ p |")
    md.append("|---|---|---|---|---|")
    for r in fiscal:
        # NaN placeholders render as 'nan' and get no stars.
        z1 = r.get('coef_Z_1', np.nan)
        z1p = r.get('p_Z_1', np.nan)
        z2 = r.get('coef_Z_2', np.nan)
        z2p = r.get('p_Z_2', np.nan)
        md.append(f"| {r['label']} | {z1:.2f}{stars(z1p)} | {z1p:.4f} "
                  f"| {z2:.2f}{stars(z2p)} | {z2p:.4f} |")

    md.append("\n*Expenditure-revenue asymmetry from fiscal dominance paper: "
              "+10pp OADR → +12pp expenditure, +5pp revenue.*")

    # Ensure the output directory exists. Write UTF-8 explicitly because the
    # table contains non-ASCII characters (—, ², ₁, →).
    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    out = TABLES_DIR / "table4b_fiscal_paradox.md"
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


# Run the Phase 4 pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
